Hui, J., 2021. GAN — Wasserstein GAN & WGAN-GP [WWW Document]. Medium. URL https://jonathan-hui.medium.com/gan-wasserstein-gan-wgan-gp-6a1a2aa1b490 (accessed 1.31.23).

Arjovsky, M., Chintala, S., Bottou, L., 2017. Wasserstein GAN [WWW Document]. arXiv.org. URL https://arxiv.org/abs/1701.07875v3 (accessed 2.1.23).

10. What is the 'W' in WGAN?¶

W here is Wasserstein named after mathematical concept wasserstein distance (Arjovsky et al., 2017)

10.1 What is Wasserstein loss or Earth-Mover(EM) distance¶

Wasserstein distance is the minimum cost of transporting mass in converting the data distribution q to the data distribution p.

From the WGAN paper (Arjovsky et al., 2017) : The intuition of the formula above is measuring the distance between 2 distribution

Simplified formula: D(x) - D(G(z))

D(x) is the critic's output for a real instance. G(z) is the generator's output when given noise z. D(G(z)) is the critic's output for a fake instance

It takes confidence of the discriminator's prediction on a real image and then deducting the confidennce of the discrinator's prediction for a fake image

10.2 How is WGAN better¶

From the Wasserstein GAN (Hui, 2021): WGAN first looked at KL and JS divergence and realised that if the 2nd Gaussian distribution Q becomes too large , gradient of divergency diminishes and generator learns nothing.

KL uses JS divergence uses the KL diverge to calculate a more normalized score which gives a value between 0 to 1 . On top of that the measure of the difference between the real and fake distribution is symmetric. People also call the JS divergence the average KL divergence since it multiplies by half of pretty much the KL divergence formula.

The figure above from the original WGAN paper (Arjovsky et al., 2017) shows an optimal discriminator trying to differentiate two Gaussians (aka 2 types of images) using the original loss function Binary Cross entropy = abs(y_pred - y_true). However as it trains , the discriminator draws the boundaries but also result in a vanishing gradient. This is a huge problem because without the critic's gradient being fed to the generator , training process stops. However, the Wasserstein loss appears to provide a clean gradient for the generator to continue learning

10.2.1 Key benefits:¶

  • removes vanishing gradient problem
  • loss metric helps with generator's convergence
  • improves training stability
In [3]:
# pip install visualkeras
Requirement already satisfied: visualkeras in c:\users\lg\.conda\envs\gpu_env\lib\site-packages (0.0.2)Note: you may need to restart the kernel to use updated packages.

Requirement already satisfied: aggdraw>=1.3.11 in c:\users\lg\.conda\envs\gpu_env\lib\site-packages (from visualkeras) (1.3.15)
Requirement already satisfied: pillow>=6.2.0 in c:\users\lg\.conda\envs\gpu_env\lib\site-packages (from visualkeras) (9.2.0)
Requirement already satisfied: numpy>=1.18.1 in c:\users\lg\.conda\envs\gpu_env\lib\site-packages (from visualkeras) (1.23.4)
In [9]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import pandas as pd
import visualkeras
from keras.utils import to_categorical
In [2]:
IMG_SHAPE = (32,32,3)
BATCH_SIZE = 512

# Size of the noise vector
noise_dim = 128

cifar10 = keras.datasets.cifar10
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()
print(f"Number of examples: {len(train_images)}")
print(f"Shape of the images in the dataset: {train_images.shape[1:]}")

# Reshape each sample to (32,32,3) and normalize the pixel values in the [-1, 1] range
train_images = train_images.reshape(train_images.shape[0], *IMG_SHAPE).astype("float32")
train_images = (train_images - 127.5) / 127.5
Number of examples: 50000
Shape of the images in the dataset: (32, 32, 3)
In [7]:
def conv_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="same",
    use_bias=True,
    use_bn=False,
    use_dropout=False,
    drop_value=0.5,
):
    x = layers.Conv2D(
        filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
    )(x)
    if use_bn:
        x = layers.BatchNormalization()(x)
    x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
    return x


def get_discriminator_model():
    img_input = layers.Input(shape=IMG_SHAPE)
    # Zero pad the input to make the input images size to (32, 32, 3).
    x = layers.ZeroPadding2D((2, 2))(img_input)
    x = conv_block(
        x,
        64,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        use_bias=True,
        activation=layers.LeakyReLU(0.2),
        use_dropout=False,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        128,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=True,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        256,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=True,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        512,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=False,
        drop_value=0.3,
    )

    x = layers.Flatten()(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(1)(x)

    d_model = keras.models.Model(img_input, x, name="discriminator")
    return d_model


d_model = get_discriminator_model()
d_model.summary()
Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 zero_padding2d (ZeroPadding  (None, 36, 36, 3)        0         
 2D)                                                             
                                                                 
 conv2d (Conv2D)             (None, 18, 18, 64)        4864      
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 18, 18, 64)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 9, 9, 128)         204928    
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 9, 9, 128)         0         
                                                                 
 dropout (Dropout)           (None, 9, 9, 128)         0         
                                                                 
 conv2d_2 (Conv2D)           (None, 5, 5, 256)         819456    
                                                                 
 leaky_re_lu_2 (LeakyReLU)   (None, 5, 5, 256)         0         
                                                                 
 dropout_1 (Dropout)         (None, 5, 5, 256)         0         
                                                                 
 conv2d_3 (Conv2D)           (None, 3, 3, 512)         3277312   
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 3, 3, 512)         0         
                                                                 
 flatten (Flatten)           (None, 4608)              0         
                                                                 
 dropout_2 (Dropout)         (None, 4608)              0         
                                                                 
 dense (Dense)               (None, 1)                 4609      
                                                                 
=================================================================
Total params: 4,311,169
Trainable params: 4,311,169
Non-trainable params: 0
_________________________________________________________________
In [8]:
import visualkeras
visualkeras.layered_view(d_model, legend=True)
Out[8]:
In [9]:
def upsample_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    up_size=(2, 2),
    padding="same",
    use_bn=False,
    use_bias=True,
    use_dropout=False,
    drop_value=0.3,
):
    x = layers.UpSampling2D(up_size)(x)
    x = layers.Conv2D(
        filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
    )(x)

    if use_bn:
        x = layers.BatchNormalization()(x)

    if activation:
        x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
    return x


def get_generator_model():
    noise = layers.Input(shape=(noise_dim,))
    x = layers.Dense(4 * 4 * 256, use_bias=False)(noise)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Reshape((4, 4, 256))(x)
    x = upsample_block(
        x,
        128,
        layers.LeakyReLU(0.2),
        strides=(1, 1),
        use_bias=False,
        use_bn=True,
        padding="same",
        use_dropout=False,
    )
    x = upsample_block(
        x,
        64,
        layers.LeakyReLU(0.2),
        strides=(1, 1),
        use_bias=False,
        use_bn=True,
        padding="same",
        use_dropout=False,
    )
    x = upsample_block(
        x, 3, layers.Activation("tanh"), strides=(1, 1), use_bias=False, use_bn=True
    )
    # At this point, we have an output which has the same shape as the input, (32, 32, 1).
    # We will use a Cropping2D layer to make it (28, 28, 1).
    # x = layers.Cropping2D((2, 2))(x)
    # remove cropping layer?

    g_model = keras.models.Model(noise, x, name="generator")
    return g_model


g_model = get_generator_model()
g_model.summary()
Model: "generator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 128)]             0         
                                                                 
 dense_1 (Dense)             (None, 4096)              524288    
                                                                 
 batch_normalization (BatchN  (None, 4096)             16384     
 ormalization)                                                   
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 4096)              0         
                                                                 
 reshape (Reshape)           (None, 4, 4, 256)         0         
                                                                 
 up_sampling2d (UpSampling2D  (None, 8, 8, 256)        0         
 )                                                               
                                                                 
 conv2d_4 (Conv2D)           (None, 8, 8, 128)         294912    
                                                                 
 batch_normalization_1 (Batc  (None, 8, 8, 128)        512       
 hNormalization)                                                 
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 8, 8, 128)         0         
                                                                 
 up_sampling2d_1 (UpSampling  (None, 16, 16, 128)      0         
 2D)                                                             
                                                                 
 conv2d_5 (Conv2D)           (None, 16, 16, 64)        73728     
                                                                 
 batch_normalization_2 (Batc  (None, 16, 16, 64)       256       
 hNormalization)                                                 
                                                                 
 leaky_re_lu_6 (LeakyReLU)   (None, 16, 16, 64)        0         
                                                                 
 up_sampling2d_2 (UpSampling  (None, 32, 32, 64)       0         
 2D)                                                             
                                                                 
 conv2d_6 (Conv2D)           (None, 32, 32, 3)         1728      
                                                                 
 batch_normalization_3 (Batc  (None, 32, 32, 3)        12        
 hNormalization)                                                 
                                                                 
 activation (Activation)     (None, 32, 32, 3)         0         
                                                                 
=================================================================
Total params: 911,820
Trainable params: 903,238
Non-trainable params: 8,582
_________________________________________________________________
In [10]:
visualkeras.layered_view(g_model, legend=True)
Out[10]:
In [5]:
class WGAN(keras.Model):
    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.
        """
        # Get the interpolated image
        alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator(interpolated, training=True)

        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        # For each batch, we are going to perform the
        # following steps as laid out in the original paper:
        # 1. Train the generator and get the generator loss
        # 2. Train the discriminator and get the discriminator loss
        # 3. Calculate the gradient penalty
        # 4. Multiply this gradient penalty with a constant weight factor
        # 5. Add the gradient penalty to the discriminator loss
        # 6. Return the generator and discriminator losses as a loss dictionary

        # Train the discriminator first. The original paper recommends training
        # the discriminator for `x` more steps (typically 5) as compared to
        # one step of the generator. Here we will train it for 3 extra steps
        # as compared to 5 to reduce the training time.
        for i in range(self.d_steps):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(random_latent_vectors, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator(fake_images, training=True)
                # Get the logits for the real images
                real_logits = self.discriminator(real_images, training=True)

                # Calculate the discriminator loss using the fake and real image logits
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Train the generator
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(random_latent_vectors, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(generated_images, training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}
In [6]:
class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=100, latent_dim=128):
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.vmin = 0
        self.vmax = 1


    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        generated_images -= self.vmin
        generated_images /= (self.vmax - self.vmin)
        # Create a figure with 10 rows and 10 columns

            
        if (epoch % 50 == 1):
            fig, axes = plt.subplots(10,10, figsize=(10, 10))
            axes = axes.ravel()
            for i in range(self.num_img):
                img = keras.preprocessing.image.array_to_img(generated_images[i])
                axes[i].imshow(img)
                axes[i].axis("off")
            title = "Epoch " + str(epoch)
            plt.suptitle(title)
            if (epoch % 250 == 1):
              path = "./generated_img_" + title
              plt.savefig(path,dpi=300)
            plt.show()


        
        
            

10.3. WGAN test¶

In [7]:
# Instantiate the optimizer for both networks
# (learning_rate=0.0002, beta_1=0.5 are recommended)
generator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
discriminator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)

# Define the loss functions for the discriminator,
# which should be (fake_loss - real_loss).
# We will add the gradient penalty later to this loss function.
def discriminator_loss(real_img, fake_img):
    real_loss = tf.reduce_mean(real_img)
    fake_loss = tf.reduce_mean(fake_img)
    return fake_loss - real_loss


# Define the loss functions for the generator.
def generator_loss(fake_img):
    return -tf.reduce_mean(fake_img)


# Set the number of epochs for trainining.
epochs = 300

# Instantiate the customer `GANMonitor` Keras callback.
cbk = GANMonitor(num_img=100, latent_dim=noise_dim)

# Get the wgan model
wgan = WGAN(
    discriminator=d_model,
    generator=g_model,
    latent_dim=noise_dim,
    discriminator_extra_steps=3,
)

# Compile the wgan model
wgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    g_loss_fn=generator_loss,
    d_loss_fn=discriminator_loss,
)

# Start training
hist = wgan.fit(train_images, batch_size=BATCH_SIZE, epochs=epochs, callbacks=[cbk])
Epoch 1/300
98/98 [==============================] - 57s 493ms/step - d_loss: -7.3406 - g_loss: 21.8229
Epoch 2/300
98/98 [==============================] - ETA: 0s - d_loss: -4.1790 - g_loss: 12.7049
98/98 [==============================] - 50s 514ms/step - d_loss: -4.1810 - g_loss: 12.6020
Epoch 3/300
98/98 [==============================] - 48s 486ms/step - d_loss: -3.7075 - g_loss: 5.0171
Epoch 4/300
98/98 [==============================] - 48s 487ms/step - d_loss: -3.4194 - g_loss: 1.5335
Epoch 5/300
98/98 [==============================] - 48s 488ms/step - d_loss: -3.0173 - g_loss: -6.4131
Epoch 6/300
98/98 [==============================] - 48s 488ms/step - d_loss: -2.9612 - g_loss: -1.9092
Epoch 7/300
98/98 [==============================] - 48s 488ms/step - d_loss: -2.9240 - g_loss: -1.0536
Epoch 8/300
98/98 [==============================] - 48s 488ms/step - d_loss: -2.8378 - g_loss: -5.3020
Epoch 9/300
98/98 [==============================] - 48s 488ms/step - d_loss: -2.8439 - g_loss: -2.2552
Epoch 10/300
98/98 [==============================] - 48s 488ms/step - d_loss: -2.7490 - g_loss: -3.3330
Epoch 11/300
98/98 [==============================] - 48s 488ms/step - d_loss: -2.7314 - g_loss: -3.8568
Epoch 12/300
98/98 [==============================] - 48s 488ms/step - d_loss: -2.6351 - g_loss: -7.1157
Epoch 13/300
98/98 [==============================] - 48s 488ms/step - d_loss: -2.5093 - g_loss: -8.9468
Epoch 14/300
98/98 [==============================] - 48s 488ms/step - d_loss: -2.4658 - g_loss: -4.1055
Epoch 15/300
98/98 [==============================] - 48s 488ms/step - d_loss: -2.1847 - g_loss: -8.2322
Epoch 16/300
98/98 [==============================] - 48s 488ms/step - d_loss: -2.2778 - g_loss: -9.5953
Epoch 17/300
98/98 [==============================] - 48s 488ms/step - d_loss: -2.3049 - g_loss: -8.6796
Epoch 18/300
98/98 [==============================] - 48s 492ms/step - d_loss: -2.2723 - g_loss: -10.8831
Epoch 19/300
98/98 [==============================] - 47s 480ms/step - d_loss: -1.9372 - g_loss: -7.9568
Epoch 20/300
98/98 [==============================] - 47s 483ms/step - d_loss: -1.9333 - g_loss: -6.8688
Epoch 21/300
98/98 [==============================] - 47s 480ms/step - d_loss: -1.9050 - g_loss: -1.9870
Epoch 22/300
98/98 [==============================] - 47s 479ms/step - d_loss: -1.7534 - g_loss: -4.2289
Epoch 23/300
98/98 [==============================] - 47s 479ms/step - d_loss: -1.7087 - g_loss: 4.0505
Epoch 24/300
98/98 [==============================] - 47s 479ms/step - d_loss: -1.8705 - g_loss: -1.0945
Epoch 25/300
98/98 [==============================] - 47s 479ms/step - d_loss: -1.7655 - g_loss: -1.0792
Epoch 26/300
98/98 [==============================] - 47s 479ms/step - d_loss: -1.8105 - g_loss: -0.6885
Epoch 27/300
98/98 [==============================] - 47s 479ms/step - d_loss: -1.8894 - g_loss: 0.1071
Epoch 28/300
98/98 [==============================] - 47s 478ms/step - d_loss: -1.7708 - g_loss: -0.8322
Epoch 29/300
98/98 [==============================] - 47s 479ms/step - d_loss: -1.7459 - g_loss: -1.0736
Epoch 30/300
98/98 [==============================] - 47s 479ms/step - d_loss: -1.6770 - g_loss: 1.0342
Epoch 31/300
98/98 [==============================] - 47s 478ms/step - d_loss: -1.7013 - g_loss: 0.0665
Epoch 32/300
98/98 [==============================] - 47s 479ms/step - d_loss: -1.6664 - g_loss: -3.6770
Epoch 33/300
98/98 [==============================] - 47s 479ms/step - d_loss: -1.7129 - g_loss: -0.9532
Epoch 34/300
98/98 [==============================] - 48s 492ms/step - d_loss: -1.6525 - g_loss: 2.0594
Epoch 35/300
98/98 [==============================] - 48s 487ms/step - d_loss: -1.5863 - g_loss: -2.5276
Epoch 36/300
98/98 [==============================] - 48s 487ms/step - d_loss: -1.6913 - g_loss: -1.3285
Epoch 37/300
98/98 [==============================] - 48s 487ms/step - d_loss: -1.6211 - g_loss: -2.1618
Epoch 38/300
98/98 [==============================] - 47s 481ms/step - d_loss: -1.6025 - g_loss: 1.6889
Epoch 39/300
98/98 [==============================] - 47s 477ms/step - d_loss: -1.6253 - g_loss: 0.6202
Epoch 40/300
98/98 [==============================] - 47s 477ms/step - d_loss: -1.6837 - g_loss: 0.0393
Epoch 41/300
98/98 [==============================] - 47s 477ms/step - d_loss: -1.6113 - g_loss: -4.3944
Epoch 42/300
98/98 [==============================] - 48s 485ms/step - d_loss: -1.5834 - g_loss: -1.0026
Epoch 43/300
98/98 [==============================] - 48s 487ms/step - d_loss: -1.5120 - g_loss: 1.5168
Epoch 44/300
98/98 [==============================] - 48s 487ms/step - d_loss: -1.6232 - g_loss: 2.1421
Epoch 45/300
98/98 [==============================] - 48s 488ms/step - d_loss: -1.5178 - g_loss: 0.2262
Epoch 46/300
98/98 [==============================] - 47s 478ms/step - d_loss: -1.4970 - g_loss: -1.3411
Epoch 47/300
98/98 [==============================] - 45s 462ms/step - d_loss: -1.5803 - g_loss: -1.3414
Epoch 48/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.6683 - g_loss: -2.5523
Epoch 49/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.5444 - g_loss: 0.0131
Epoch 50/300
98/98 [==============================] - 45s 455ms/step - d_loss: -1.5371 - g_loss: -1.2098
Epoch 51/300
98/98 [==============================] - 45s 455ms/step - d_loss: -1.6387 - g_loss: 0.4212
Epoch 52/300
98/98 [==============================] - ETA: 0s - d_loss: -1.5250 - g_loss: 0.2688
98/98 [==============================] - 47s 478ms/step - d_loss: -1.5237 - g_loss: 0.2351
Epoch 53/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.5084 - g_loss: -0.9236
Epoch 54/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.5153 - g_loss: -1.2109
Epoch 55/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.5129 - g_loss: 0.2403
Epoch 56/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.5605 - g_loss: 1.3756
Epoch 57/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.4074 - g_loss: -1.5645
Epoch 58/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.4712 - g_loss: -2.6132
Epoch 59/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.4775 - g_loss: 3.8631
Epoch 60/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.5548 - g_loss: 0.8922
Epoch 61/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.4649 - g_loss: -4.5824
Epoch 62/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.5285 - g_loss: -2.4317
Epoch 63/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.4304 - g_loss: -2.4881
Epoch 64/300
98/98 [==============================] - 45s 456ms/step - d_loss: -1.4473 - g_loss: -1.1217
Epoch 65/300
98/98 [==============================] - 45s 461ms/step - d_loss: -1.5040 - g_loss: 2.7232
Epoch 66/300
98/98 [==============================] - 45s 461ms/step - d_loss: -1.4376 - g_loss: -3.5250
Epoch 67/300
98/98 [==============================] - 45s 457ms/step - d_loss: -1.4372 - g_loss: 1.2274
Epoch 68/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.4734 - g_loss: -1.7833
Epoch 69/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.4026 - g_loss: -2.7036
Epoch 70/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3780 - g_loss: -0.3760
Epoch 71/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3971 - g_loss: 0.2833
Epoch 72/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.4545 - g_loss: 0.6909
Epoch 73/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3889 - g_loss: 2.4488
Epoch 74/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3800 - g_loss: 0.9237
Epoch 75/300
98/98 [==============================] - 45s 457ms/step - d_loss: -1.4084 - g_loss: 0.4737
Epoch 76/300
98/98 [==============================] - 45s 460ms/step - d_loss: -1.4056 - g_loss: 0.7281
Epoch 77/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3699 - g_loss: 1.4609
Epoch 78/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.4318 - g_loss: -2.5051
Epoch 79/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3597 - g_loss: -0.8964
Epoch 80/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3212 - g_loss: 0.8537
Epoch 81/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3717 - g_loss: 2.4114
Epoch 82/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3176 - g_loss: -1.2447
Epoch 83/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3894 - g_loss: 1.0681
Epoch 84/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3219 - g_loss: 1.3581
Epoch 85/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3531 - g_loss: 1.3054
Epoch 86/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.2949 - g_loss: -0.5320
Epoch 87/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3341 - g_loss: 0.9290
Epoch 88/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3099 - g_loss: 1.3935
Epoch 89/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.2745 - g_loss: 0.4241
Epoch 90/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3496 - g_loss: 1.5486
Epoch 91/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.2761 - g_loss: 1.3240
Epoch 92/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3818 - g_loss: 1.5650
Epoch 93/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3010 - g_loss: 1.3619
Epoch 94/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3310 - g_loss: 1.3199
Epoch 95/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.3286 - g_loss: 4.0010
Epoch 96/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.2649 - g_loss: 1.1841
Epoch 97/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.3311 - g_loss: 0.9100
Epoch 98/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3023 - g_loss: 0.6819
Epoch 99/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3255 - g_loss: 0.5858
Epoch 100/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.3475 - g_loss: 2.6210
Epoch 101/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.2194 - g_loss: 3.1946
Epoch 102/300
98/98 [==============================] - ETA: 0s - d_loss: -1.3056 - g_loss: 1.8696
98/98 [==============================] - 47s 482ms/step - d_loss: -1.3120 - g_loss: 1.8494
Epoch 103/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.2848 - g_loss: 1.4237
Epoch 104/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.3256 - g_loss: 1.1203
Epoch 105/300
98/98 [==============================] - 45s 455ms/step - d_loss: -1.2511 - g_loss: 1.0639
Epoch 106/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.2916 - g_loss: -0.7054
Epoch 107/300
98/98 [==============================] - 45s 455ms/step - d_loss: -1.1969 - g_loss: 0.4681
Epoch 108/300
98/98 [==============================] - 45s 460ms/step - d_loss: -1.2933 - g_loss: 1.3242
Epoch 109/300
98/98 [==============================] - 45s 460ms/step - d_loss: -1.2012 - g_loss: 2.2679
Epoch 110/300
98/98 [==============================] - 45s 460ms/step - d_loss: -1.3089 - g_loss: 0.6502
Epoch 111/300
98/98 [==============================] - 45s 460ms/step - d_loss: -1.1943 - g_loss: 1.5444
Epoch 112/300
98/98 [==============================] - 45s 460ms/step - d_loss: -1.2381 - g_loss: 2.6213
Epoch 113/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.2571 - g_loss: 3.5570
Epoch 114/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.2228 - g_loss: 1.4997
Epoch 115/300
98/98 [==============================] - 45s 455ms/step - d_loss: -1.2060 - g_loss: 2.0576
Epoch 116/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.2086 - g_loss: 1.5745
Epoch 117/300
98/98 [==============================] - 45s 455ms/step - d_loss: -1.1207 - g_loss: 2.1902
Epoch 118/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.2450 - g_loss: 2.0470
Epoch 119/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.2803 - g_loss: 1.1288
Epoch 120/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.2472 - g_loss: 4.5549
Epoch 121/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.2029 - g_loss: 3.3712
Epoch 122/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.2363 - g_loss: 1.9598
Epoch 123/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1100 - g_loss: 1.3313
Epoch 124/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.2277 - g_loss: 1.6718
Epoch 125/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.2280 - g_loss: 0.8758
Epoch 126/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1831 - g_loss: 2.6083
Epoch 127/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.2510 - g_loss: 2.0048
Epoch 128/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.2414 - g_loss: 2.6791
Epoch 129/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.2265 - g_loss: 1.6053
Epoch 130/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.2032 - g_loss: 1.5942
Epoch 131/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1996 - g_loss: 0.3605
Epoch 132/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.1722 - g_loss: 1.2500
Epoch 133/300
98/98 [==============================] - 45s 458ms/step - d_loss: -1.2058 - g_loss: 2.9096
Epoch 134/300
98/98 [==============================] - 45s 458ms/step - d_loss: -1.2217 - g_loss: 2.6260
Epoch 135/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1867 - g_loss: 2.9725
Epoch 136/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.2068 - g_loss: 3.9829
Epoch 137/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1420 - g_loss: 3.3342
Epoch 138/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1942 - g_loss: 2.5772
Epoch 139/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1858 - g_loss: 4.1331
Epoch 140/300
98/98 [==============================] - 46s 465ms/step - d_loss: -1.1805 - g_loss: 1.7865
Epoch 141/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.1272 - g_loss: 1.8605
Epoch 142/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.1219 - g_loss: 1.3973
Epoch 143/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.0834 - g_loss: 0.9771
Epoch 144/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.1768 - g_loss: 1.1102
Epoch 145/300
98/98 [==============================] - 45s 457ms/step - d_loss: -1.1459 - g_loss: 1.1065
Epoch 146/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1261 - g_loss: 1.6852
Epoch 147/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1674 - g_loss: 2.6545
Epoch 148/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0942 - g_loss: 1.8003
Epoch 149/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1106 - g_loss: 3.0615
Epoch 150/300
98/98 [==============================] - 45s 456ms/step - d_loss: -1.1636 - g_loss: 5.3379
Epoch 151/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.1802 - g_loss: 3.7367
Epoch 152/300
98/98 [==============================] - ETA: 0s - d_loss: -1.1578 - g_loss: 2.6614
98/98 [==============================] - 48s 488ms/step - d_loss: -1.1581 - g_loss: 2.6190
Epoch 153/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.1354 - g_loss: 2.7578
Epoch 154/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.1498 - g_loss: 3.7828
Epoch 155/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.1098 - g_loss: 3.3301
Epoch 156/300
98/98 [==============================] - 45s 460ms/step - d_loss: -1.1638 - g_loss: 2.1495
Epoch 157/300
98/98 [==============================] - 45s 464ms/step - d_loss: -1.1111 - g_loss: 0.8604
Epoch 158/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.2022 - g_loss: 1.7162
Epoch 159/300
98/98 [==============================] - 46s 465ms/step - d_loss: -1.1373 - g_loss: 2.4267
Epoch 160/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.1175 - g_loss: 2.6139
Epoch 161/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.1568 - g_loss: 3.1018
Epoch 162/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.1616 - g_loss: 0.4112
Epoch 163/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.2065 - g_loss: 2.1675
Epoch 164/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.1737 - g_loss: 3.2349
Epoch 165/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.0799 - g_loss: 0.6955
Epoch 166/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.1050 - g_loss: -0.3816
Epoch 167/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.0218 - g_loss: 1.3857
Epoch 168/300
98/98 [==============================] - 46s 466ms/step - d_loss: -1.0829 - g_loss: 2.3473
Epoch 169/300
98/98 [==============================] - 45s 454ms/step - d_loss: -1.0777 - g_loss: 2.8695
Epoch 170/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0733 - g_loss: 2.7384
Epoch 171/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0903 - g_loss: 2.4654
Epoch 172/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0632 - g_loss: 2.6664
Epoch 173/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0899 - g_loss: 4.2382
Epoch 174/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0500 - g_loss: 1.8510
Epoch 175/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0774 - g_loss: 3.3243
Epoch 176/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1107 - g_loss: 3.3604
Epoch 177/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1003 - g_loss: 1.8708
Epoch 178/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1401 - g_loss: 2.2696
Epoch 179/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0435 - g_loss: 0.2291
Epoch 180/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0282 - g_loss: 2.5060
Epoch 181/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1153 - g_loss: 2.8579
Epoch 182/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1233 - g_loss: 1.3153
Epoch 183/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0882 - g_loss: 2.2274
Epoch 184/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1000 - g_loss: 1.7496
Epoch 185/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1021 - g_loss: 2.4306
Epoch 186/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0370 - g_loss: 2.0342
Epoch 187/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0428 - g_loss: 2.8306
Epoch 188/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1008 - g_loss: 2.4283
Epoch 189/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1217 - g_loss: 2.7710
Epoch 190/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0630 - g_loss: 2.5978
Epoch 191/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0296 - g_loss: 3.2917
Epoch 192/300
98/98 [==============================] - 45s 460ms/step - d_loss: -1.0066 - g_loss: 2.3757
Epoch 193/300
98/98 [==============================] - 45s 456ms/step - d_loss: -1.0698 - g_loss: 0.1348
Epoch 194/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0121 - g_loss: 1.0264
Epoch 195/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9906 - g_loss: 1.3570
Epoch 196/300
98/98 [==============================] - 44s 454ms/step - d_loss: -1.0438 - g_loss: 2.2373
Epoch 197/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0402 - g_loss: -0.1310
Epoch 198/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0488 - g_loss: 3.4424
Epoch 199/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.1399 - g_loss: 4.2830
Epoch 200/300
98/98 [==============================] - 45s 459ms/step - d_loss: -1.0664 - g_loss: 3.6309
Epoch 201/300
98/98 [==============================] - 45s 460ms/step - d_loss: -0.9763 - g_loss: 4.2395
Epoch 202/300
98/98 [==============================] - ETA: 0s - d_loss: -1.0156 - g_loss: 3.0092
98/98 [==============================] - 47s 480ms/step - d_loss: -1.0251 - g_loss: 2.9691
Epoch 203/300
98/98 [==============================] - 44s 452ms/step - d_loss: -1.0047 - g_loss: 2.4707
Epoch 204/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0519 - g_loss: 2.5521
Epoch 205/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9330 - g_loss: 1.7011
Epoch 206/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0218 - g_loss: 1.6597
Epoch 207/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9699 - g_loss: 1.5330
Epoch 208/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0143 - g_loss: 1.2662
Epoch 209/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9972 - g_loss: 0.4574
Epoch 210/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0503 - g_loss: 0.5067
Epoch 211/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9710 - g_loss: 1.3070
Epoch 212/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0635 - g_loss: 0.3515
Epoch 213/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9929 - g_loss: 0.0164
Epoch 214/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0498 - g_loss: 1.6477
Epoch 215/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9700 - g_loss: 0.7667
Epoch 216/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0267 - g_loss: -0.8543
Epoch 217/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9466 - g_loss: 1.1950
Epoch 218/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0177 - g_loss: 0.7497
Epoch 219/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9565 - g_loss: 1.1094
Epoch 220/300
98/98 [==============================] - 45s 457ms/step - d_loss: -1.0675 - g_loss: 2.2064
Epoch 221/300
98/98 [==============================] - 45s 459ms/step - d_loss: -1.0404 - g_loss: 0.8887
Epoch 222/300
98/98 [==============================] - 45s 459ms/step - d_loss: -0.9818 - g_loss: 1.1712
Epoch 223/300
98/98 [==============================] - 45s 459ms/step - d_loss: -1.0265 - g_loss: 2.2391
Epoch 224/300
98/98 [==============================] - 45s 459ms/step - d_loss: -0.8715 - g_loss: 3.4849
Epoch 225/300
98/98 [==============================] - 45s 459ms/step - d_loss: -1.0107 - g_loss: 1.8880
Epoch 226/300
98/98 [==============================] - 45s 459ms/step - d_loss: -1.0036 - g_loss: 1.7467
Epoch 227/300
98/98 [==============================] - 45s 459ms/step - d_loss: -0.9965 - g_loss: 2.6096
Epoch 228/300
98/98 [==============================] - 45s 460ms/step - d_loss: -0.9812 - g_loss: 1.6501
Epoch 229/300
98/98 [==============================] - 45s 459ms/step - d_loss: -0.9855 - g_loss: 2.3651
Epoch 230/300
98/98 [==============================] - 45s 459ms/step - d_loss: -1.0264 - g_loss: 3.1975
Epoch 231/300
98/98 [==============================] - 45s 460ms/step - d_loss: -1.0250 - g_loss: 0.7918
Epoch 232/300
98/98 [==============================] - 45s 460ms/step - d_loss: -0.9574 - g_loss: 0.7262
Epoch 233/300
98/98 [==============================] - 45s 457ms/step - d_loss: -0.9317 - g_loss: -1.0323
Epoch 234/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9810 - g_loss: -0.4640
Epoch 235/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9978 - g_loss: 0.6541
Epoch 236/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0407 - g_loss: 1.0960
Epoch 237/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9414 - g_loss: 1.9438
Epoch 238/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9657 - g_loss: -0.6193
Epoch 239/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9495 - g_loss: 1.5019
Epoch 240/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9005 - g_loss: 2.8085
Epoch 241/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.8969 - g_loss: 1.0215
Epoch 242/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9676 - g_loss: 1.0533
Epoch 243/300
98/98 [==============================] - 44s 453ms/step - d_loss: -1.0075 - g_loss: 1.4819
Epoch 244/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9642 - g_loss: 2.1879
Epoch 245/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9639 - g_loss: 0.7668
Epoch 246/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9948 - g_loss: -0.5589
Epoch 247/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9146 - g_loss: 2.3022
Epoch 248/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9961 - g_loss: 0.5395
Epoch 249/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9905 - g_loss: 0.5777
Epoch 250/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.8771 - g_loss: 0.2384
Epoch 251/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9145 - g_loss: 0.8778
Epoch 252/300
98/98 [==============================] - ETA: 0s - d_loss: -0.9677 - g_loss: 1.6473
98/98 [==============================] - 47s 483ms/step - d_loss: -0.9694 - g_loss: 1.6683
Epoch 253/300
98/98 [==============================] - 44s 452ms/step - d_loss: -0.9292 - g_loss: 0.7109
Epoch 254/300
98/98 [==============================] - 44s 452ms/step - d_loss: -0.9489 - g_loss: 1.2075
Epoch 255/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.8722 - g_loss: 1.7124
Epoch 256/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9398 - g_loss: 0.4124
Epoch 257/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9371 - g_loss: 2.0218
Epoch 258/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9235 - g_loss: 1.3930
Epoch 259/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9945 - g_loss: 2.2995
Epoch 260/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9156 - g_loss: 2.4820
Epoch 261/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9120 - g_loss: 2.6750
Epoch 262/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9213 - g_loss: 2.3087
Epoch 263/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.8734 - g_loss: 1.2678
Epoch 264/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9443 - g_loss: 1.5327
Epoch 265/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9587 - g_loss: 2.6723
Epoch 266/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9363 - g_loss: 2.5637
Epoch 267/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9374 - g_loss: 1.9842
Epoch 268/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.8912 - g_loss: 1.9781
Epoch 269/300
98/98 [==============================] - 44s 453ms/step - d_loss: -0.9396 - g_loss: 1.4545
Epoch 270/300
98/98 [==============================] - 47s 477ms/step - d_loss: -0.9179 - g_loss: 0.9802
Epoch 271/300
98/98 [==============================] - 47s 483ms/step - d_loss: -0.9837 - g_loss: 0.6298
Epoch 272/300
98/98 [==============================] - 48s 485ms/step - d_loss: -0.9375 - g_loss: 0.2567
Epoch 273/300
98/98 [==============================] - 47s 483ms/step - d_loss: -0.9702 - g_loss: 0.6118
Epoch 274/300
98/98 [==============================] - 47s 484ms/step - d_loss: -1.0137 - g_loss: 0.4975
Epoch 275/300
98/98 [==============================] - 47s 483ms/step - d_loss: -0.9480 - g_loss: -0.3920
Epoch 276/300
98/98 [==============================] - 47s 483ms/step - d_loss: -0.8967 - g_loss: 1.1590
Epoch 277/300
98/98 [==============================] - 47s 476ms/step - d_loss: -0.8559 - g_loss: 1.7176
Epoch 278/300
98/98 [==============================] - 44s 452ms/step - d_loss: -0.8888 - g_loss: 1.5713
Epoch 279/300
98/98 [==============================] - 44s 452ms/step - d_loss: -0.9089 - g_loss: -0.1344
Epoch 280/300
98/98 [==============================] - 46s 470ms/step - d_loss: -0.9031 - g_loss: -0.1420
Epoch 281/300
98/98 [==============================] - 47s 481ms/step - d_loss: -0.8853 - g_loss: -0.2632
Epoch 282/300
98/98 [==============================] - 47s 482ms/step - d_loss: -0.9200 - g_loss: 1.7686
Epoch 283/300
98/98 [==============================] - 47s 481ms/step - d_loss: -0.8523 - g_loss: 1.4383
Epoch 284/300
98/98 [==============================] - 47s 481ms/step - d_loss: -0.9040 - g_loss: 1.0951
Epoch 285/300
98/98 [==============================] - 47s 482ms/step - d_loss: -0.9394 - g_loss: 2.6546
Epoch 286/300
98/98 [==============================] - 47s 483ms/step - d_loss: -0.9170 - g_loss: 0.1600
Epoch 287/300
98/98 [==============================] - 47s 483ms/step - d_loss: -0.9561 - g_loss: -0.1924
Epoch 288/300
98/98 [==============================] - 47s 482ms/step - d_loss: -0.8706 - g_loss: 1.3418
Epoch 289/300
98/98 [==============================] - 47s 481ms/step - d_loss: -0.9435 - g_loss: 1.6790
Epoch 290/300
98/98 [==============================] - 47s 484ms/step - d_loss: -0.9301 - g_loss: 1.1571
Epoch 291/300
98/98 [==============================] - 47s 481ms/step - d_loss: -0.8730 - g_loss: 1.8544
Epoch 292/300
98/98 [==============================] - 47s 482ms/step - d_loss: -0.8819 - g_loss: 1.2300
Epoch 293/300
98/98 [==============================] - 47s 483ms/step - d_loss: -0.8592 - g_loss: 1.3509
Epoch 294/300
98/98 [==============================] - 47s 482ms/step - d_loss: -0.9590 - g_loss: 0.9401
Epoch 295/300
98/98 [==============================] - 47s 482ms/step - d_loss: -0.9202 - g_loss: 2.5157
Epoch 296/300
98/98 [==============================] - 48s 485ms/step - d_loss: -0.9116 - g_loss: 2.0525
Epoch 297/300
98/98 [==============================] - 47s 481ms/step - d_loss: -0.8655 - g_loss: 0.4395
Epoch 298/300
98/98 [==============================] - 47s 482ms/step - d_loss: -0.9548 - g_loss: 2.3430
Epoch 299/300
98/98 [==============================] - 47s 482ms/step - d_loss: -0.9313 - g_loss: 2.3673
Epoch 300/300
98/98 [==============================] - 47s 482ms/step - d_loss: -0.8926 - g_loss: 1.3399

10.4 Evaluation¶

In [8]:
# story history object into dataframe
hist_df = pd.DataFrame(hist.history)

# using pandas dataframe to plot out learning curve
with plt.style.context('seaborn'):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8), tight_layout=True)
    hist_df.loc[:, ["d_loss", 'g_loss']].plot(ax=ax1, title=r'Learning Curve of Loss Function CE')
    plt.show()

10.5 Loss function evaluation¶

Test was smooth and loss was very stable even for high epochs

In [9]:
tf.keras.models.save_model(d_model , "./discriminator_wgan4.h5")
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
In [10]:
tf.keras.models.save_model(g_model , "./generator_wgan4.h5")
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
In [3]:
gen4=tf.keras.models.load_model("./Models/WGAN/generator_wgan4.h5")
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.

10.6 Visual Evaluation¶

In [4]:
random_latent_vectors = tf.random.normal(shape=(100, 128))
generated_images =gen4(random_latent_vectors)
generated_images -= -1
generated_images /= (1 - 0)
# Create a figure with 10 rows and 10 column

fig, axes = plt.subplots(10,10, figsize=(10, 10))
axes = axes.ravel()
for i in range(100):
    img = keras.preprocessing.image.array_to_img(generated_images[i])
    axes[i].imshow(img)
    axes[i].axis("off")

10.7 FID Evaluation¶

In [11]:
# import relevant libraries to calculate FID
import numpy as np
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from tensorflow.image import resize
from scipy.linalg import sqrtm
import math
from tqdm import tqdm
import tensorflow as tf
class GAN_FID:
    def __init__(self, batch_size, latent_dim, sample_size, buffer_size):
        # setting Hyperparameters
        self.BATCH_SIZE = batch_size
        self.LATENT_DIM = latent_dim
        self.SAMPLE_SIZE = sample_size
        self.BUFFER_SIZE = buffer_size

        # setting Constants
        self.INCEPTION_SHAPE = (299, 299, 3)
        self.INCEPTION = InceptionV3(include_top=False, pooling='avg', input_shape=self.INCEPTION_SHAPE)
        self.AUTO = tf.data.AUTOTUNE

    # method to set generator and training data
    def fit(self, generator, train_data):
        # setting generative model and original data used for training 
        self.GENERATOR = generator
        self.train_data = train_data

        # Preparing Real Images
        trainloader = tf.data.Dataset.from_tensor_slices((self.train_data))
        trainloader = (
            trainloader
            .shuffle(self.BUFFER_SIZE)
            .map(self.__resize_and_preprocess, num_parallel_calls=self.AUTO)
            .batch(self.BATCH_SIZE, num_parallel_calls=self.AUTO)
            .prefetch(self.AUTO)
        )
        self.trainloader = trainloader

        # Generate and prepare Synthetic Images (Fake)
        rand_labels = np.random.randint(low=0, high=10, size=self.SAMPLE_SIZE)
        rand_labels = to_categorical(rand_labels)

        random_latent_vectors = tf.random.normal(shape=(1000, self.LATENT_DIM))
        generated_images = generator(random_latent_vectors)
        

        genloader = tf.data.Dataset.from_tensor_slices(generated_images)
        genloader = (
            genloader
            .map(self.__resize_and_preprocess, num_parallel_calls=self.AUTO)
            .batch(self.BATCH_SIZE, num_parallel_calls=self.AUTO)
            .prefetch(self.AUTO)
        )
        self.genloader = genloader

        # prepare embeddings
        count = math.ceil(self.SAMPLE_SIZE/self.BATCH_SIZE)

        ## compute embeddings for real images
        print("Computing Real Image Embeddings")
        self.real_image_embeddings = self.__compute_embeddings(self.trainloader, count)

        ## compute embeddings for generated images
        print("Computing Generated Image Embeddings")
        self.generated_image_embeddings = self.__compute_embeddings(self.genloader, count)
        # assert self.real_image_embeddings.shape == self.generated_image_embeddings.shape, "Embeddings are not of the same size"
        print("Computed Embeddings\tReal Images Embedding Shape: {}\tGenerated Images Embedding Shape".format(
            self.real_image_embeddings.shape, 
            self.generated_image_embeddings.shape
        ))
    
    # method to produce evaluation results
    @tf.autograph.experimental.do_not_convert
    def evaluate(self):
        # calculate Frechet Inception Distance
        fid = self.__calculate_fid(self.real_image_embeddings, self.generated_image_embeddings)
        print('The computed FID score is:', fid)

        return fid

    # method to generate embeddings from inception model 
    def __compute_embeddings(self, dataloader, count):
        image_embeddings = []
        for _ in tqdm(range(count)):
            images = next(iter(dataloader))
            embeddings = self.INCEPTION.predict(images)
            image_embeddings.extend(embeddings)
        return np.array(image_embeddings)

    ## STATIC METHODS: these methods knows nothing about the class
    # static method to prepare the data before computing Inception Embeddings
    @staticmethod
    def __resize_and_preprocess(image):
        # image *= 255.0 # original image are scaled to [0, 1], scaling back to [0, 255]

        image = tf.image.convert_image_dtype(image, dtype=tf.float32, saturate=True)

    
        # .preprocess_input() expects an image of scale [0, 255]
        image = preprocess_input(image)
        # inception model expects an image of shape (None, 299, 299, 3)
        image = tf.image.resize(image, (299, 299), method='nearest')
        return image

    # static method to calculate frechet inception distance based on embeddings
    @staticmethod 
    def __calculate_fid(real_embeddings, generated_embeddings):
        # calculate mean and covariance statistics
        mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings, rowvar=False)
        mu2, sigma2 = generated_embeddings.mean(axis=0), np.cov(generated_embeddings, rowvar=False)
        # calculate sum squared difference between means
        ssdiff = np.sum((mu1 - mu2)**2.0)
        # calculate sqrt of product between cov
        covmean = sqrtm(sigma1.dot(sigma2))
        # check and correct imaginary numbers from sqrt
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        # calculate score
        fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
        return fid
In [13]:
%%time
fid_class = GAN_FID(batch_size=512, latent_dim=128, sample_size=10000, buffer_size=1024)
fid_class.fit(generator=gen4, train_data=test_images)
fid_score = fid_class.evaluate()
Computing Real Image Embeddings
  0%|          | 0/20 [00:00<?, ?it/s]
16/16 [==============================] - 3s 146ms/step
  5%|▌         | 1/20 [00:03<01:13,  3.87s/it]
16/16 [==============================] - 2s 152ms/step
 10%|█         | 2/20 [00:06<00:59,  3.30s/it]
16/16 [==============================] - 2s 99ms/step
 15%|█▌        | 3/20 [00:08<00:47,  2.78s/it]
16/16 [==============================] - 1s 92ms/step
 20%|██        | 4/20 [00:10<00:39,  2.47s/it]
16/16 [==============================] - 2s 133ms/step
 25%|██▌       | 5/20 [00:13<00:37,  2.53s/it]
16/16 [==============================] - 3s 180ms/step
 30%|███       | 6/20 [00:16<00:39,  2.80s/it]
16/16 [==============================] - 2s 109ms/step
 35%|███▌      | 7/20 [00:19<00:34,  2.62s/it]
16/16 [==============================] - 1s 73ms/step
 40%|████      | 8/20 [00:20<00:27,  2.33s/it]
16/16 [==============================] - 1s 89ms/step
 45%|████▌     | 9/20 [00:22<00:24,  2.21s/it]
16/16 [==============================] - 2s 141ms/step
 50%|█████     | 10/20 [00:25<00:24,  2.44s/it]
16/16 [==============================] - 2s 125ms/step
 55%|█████▌    | 11/20 [00:28<00:22,  2.47s/it]
16/16 [==============================] - 2s 131ms/step
 60%|██████    | 12/20 [00:30<00:20,  2.53s/it]
16/16 [==============================] - 2s 145ms/step
 65%|██████▌   | 13/20 [00:33<00:18,  2.62s/it]
16/16 [==============================] - 2s 123ms/step
 70%|███████   | 14/20 [00:36<00:15,  2.59s/it]
16/16 [==============================] - 2s 129ms/step
 75%|███████▌  | 15/20 [00:38<00:12,  2.59s/it]
16/16 [==============================] - 2s 142ms/step
 80%|████████  | 16/20 [00:41<00:10,  2.65s/it]
16/16 [==============================] - 2s 143ms/step
 85%|████████▌ | 17/20 [00:44<00:08,  2.74s/it]
16/16 [==============================] - 2s 147ms/step
 90%|█████████ | 18/20 [00:47<00:05,  2.93s/it]
16/16 [==============================] - 3s 176ms/step
 95%|█████████▌| 19/20 [00:51<00:03,  3.06s/it]
16/16 [==============================] - 2s 151ms/step
100%|██████████| 20/20 [00:54<00:00,  2.72s/it]
Computing Generated Image Embeddings
  0%|          | 0/20 [00:00<?, ?it/s]
16/16 [==============================] - 2s 128ms/step
  5%|▌         | 1/20 [00:02<00:48,  2.56s/it]
16/16 [==============================] - 2s 144ms/step
 10%|█         | 2/20 [00:05<00:48,  2.71s/it]
16/16 [==============================] - 2s 126ms/step
 15%|█▌        | 3/20 [00:07<00:44,  2.64s/it]
16/16 [==============================] - 2s 155ms/step
 20%|██        | 4/20 [00:10<00:44,  2.80s/it]
16/16 [==============================] - 2s 146ms/step
 25%|██▌       | 5/20 [00:13<00:42,  2.84s/it]
16/16 [==============================] - 3s 178ms/step
 30%|███       | 6/20 [00:17<00:42,  3.07s/it]
16/16 [==============================] - 4s 236ms/step
 35%|███▌      | 7/20 [00:21<00:44,  3.45s/it]
16/16 [==============================] - 2s 156ms/step
 40%|████      | 8/20 [00:24<00:40,  3.35s/it]
16/16 [==============================] - 3s 184ms/step
 45%|████▌     | 9/20 [00:28<00:38,  3.53s/it]
16/16 [==============================] - 2s 129ms/step
 50%|█████     | 10/20 [00:31<00:32,  3.24s/it]
16/16 [==============================] - 7s 492ms/step
 55%|█████▌    | 11/20 [00:39<00:42,  4.70s/it]
16/16 [==============================] - 8s 526ms/step
 60%|██████    | 12/20 [00:48<00:48,  6.08s/it]
16/16 [==============================] - 3s 199ms/step
 65%|██████▌   | 13/20 [00:52<00:37,  5.36s/it]
16/16 [==============================] - 4s 236ms/step
 70%|███████   | 14/20 [00:56<00:30,  5.13s/it]
16/16 [==============================] - 3s 167ms/step
 75%|███████▌  | 15/20 [01:00<00:22,  4.56s/it]
16/16 [==============================] - 4s 242ms/step
 80%|████████  | 16/20 [01:04<00:17,  4.49s/it]
16/16 [==============================] - 2s 140ms/step
 85%|████████▌ | 17/20 [01:07<00:11,  4.00s/it]
16/16 [==============================] - 2s 135ms/step
 90%|█████████ | 18/20 [01:09<00:07,  3.63s/it]
16/16 [==============================] - 2s 152ms/step
 95%|█████████▌| 19/20 [01:13<00:03,  3.61s/it]
16/16 [==============================] - 7s 445ms/step
100%|██████████| 20/20 [01:22<00:00,  4.11s/it]
Computed Embeddings	Real Images Embedding Shape: (10240, 2048)	Generated Images Embedding Shape
The computed FID score is: 5.453099102467402
CPU times: total: 4min 46s
Wall time: 2min 24s

10.8 Conclusion of WGAN¶

WGAN is a better GAN than GAN because it has a better loss function and it is more stable. However, the images are not superb and I would say the images are kind of the same or maybe just slightly better than the original DCGAN images. I would say it is too computationally expensive to continue exploring WGAN and even trying out Data augmentation which will take a lot longer than the current Spectral Normalization CGAN that i did before this. Therefore I will submit the images generated by Spectral Normalization CGAN as my final submission.